#include <sys/time.h>
#include <stdio.h>
#include <climits>
#include "cell.h"
#include "path.hpp"
#include "trajectory.hpp"

#include "shake.h"
//#include "rattle.h"
#include "comm.h"
#include "force_field.h"
#include "io_dcd.h"
#include "verlet.h"
#include "timer.h"
#include "bonded.h"
#include "listed.hpp"
#include "hmini.hpp"

//#define DEBUG_X
// Timers for the per-step phases; each is started/stopped with
// timer_start()/timer_stop() inside integrate_verlet() below.
DEF_TIMER(FORCE, "verlet/force computation")
DEF_TIMER(INITIAL_INTEGRATE, "verlet/initial integrate")
DEF_TIMER(FINAL_INTEGRATE, "verlet/final integrate")
DEF_TIMER(CONSTRAINT, "verlet/constraint")
DEF_TIMER(CELLS, "verlet/cells")
#ifdef __sw__
#include "sw/swarch.h"
#include <qthread.h>
#endif
//  #define DEBUG_X
/**
 * First half of a velocity-Verlet step over every atom of the local cells.
 *
 * Per atom, in place:
 *   v = v * vscale + f * rmass * dt / 2 / 48.88821291^2
 *   x = x + v * dt
 *
 * With dt == 0 the call therefore only rescales velocities, which is how
 * integrate_verlet() flushes a pending thermostat scale after its loop.
 *
 * @param grid   cell grid whose local cells are updated
 * @param dt     time step
 * @param vscale factor applied to each velocity before the half kick
 *               (1.0 leaves velocities untouched)
 * @return the value of cell_check(grid); integrate_verlet() treats a
 *         nonzero result as "cell grid must be rebuilt"
 */
int initial_integrate_verlet(cellgrid_t *grid, real dt, real vscale) {
#ifdef __sw__
  return initial_integrate_verlet_sw(grid, dt, vscale);
#endif
  FOREACH_LOCAL_CELL(grid, ii, jj, kk, cell) {
    for (int i = 0; i < cell->natom; i++) {
      // Apply the deferred velocity scale first so the half kick acts on the
      // rescaled velocity; multiplying by 1.0 is exact, so no special case.
      cell->v[i] = cell->v[i] * vscale;
      // 48.88821291 is the same unit-conversion constant used by
      // final_integrate_verlet (presumably the AKMA time unit in fs -- confirm).
      cell->v[i] += cell->f[i] * (cell->rmass[i] * dt * 0.5 / 48.88821291 / 48.88821291);
      cell->x[i] += cell->v[i] * dt;
    }
  }
  return cell_check(grid);
}
/**
 * Second half kick of velocity Verlet: adds f * rmass * dt / 2 (divided twice
 * by the 48.88821291 conversion constant) to every local atom's velocity.
 * Positions are not touched. On __sw__ builds the work is delegated to
 * final_integrate_verlet_sw().
 *
 * @param grid cell grid whose local cells are updated in place
 * @param dt   time step
 */
void final_integrate_verlet(cellgrid_t *grid, real dt) {
#ifdef __sw__
  final_integrate_verlet_sw(grid, dt);
#else
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
    for (int atom = 0; atom < cell->natom; ++atom) {
      // Same evaluation order as the first half kick, so both halves round identically.
      const auto half_kick = cell->rmass[atom] * dt * 0.5 / 48.88821291 / 48.88821291;
      cell->v[atom] += cell->f[atom] * half_kick;
    }
  }
#endif
}
/**
 * Sum of 0.5 * mass * |v|^2 over the atoms of the local cells.
 * No communication is performed and no unit conversion is applied.
 *
 * @param grid cell grid to scan
 * @return the accumulated sum for the local cells
 */
real get_kin(cellgrid_t *grid) {
  real total = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
    int atom = 0;
    while (atom < cell->natom) {
      total += 0.5 * cell->mass[atom] * cell->v[atom].norm2();
      ++atom;
    }
  }
  return total;
}

// Timers for the individual stages of update_cell_grid() below.
DEF_TIMER(CELL_EXPORT, "cell_export")
DEF_TIMER(CELL_IMPORT, "cell_import")
DEF_TIMER(FWD_MOST, "forward most")
DEF_TIMER(FWD_EXPORT, "forward export")
/**
 * Refresh the cell grid after integrate_verlet() has decided a rebuild is needed.
 *
 * Runs, each under its own timer and in this order: cell_export,
 * forward_comm_export_list, cell_import, forward_comm_most. Afterwards the
 * grid's topology object is updated, and on __sw__ builds the pack
 * parameters are rebuilt.
 *
 * @param ff  force-field context; its grid is the one being refreshed
 * @param mpp communication context handed to the forward_comm_* helpers
 */
void update_cell_grid(ff_t *ff, mpp_t *mpp) {
  auto *const grid = ff->grid;

  timer_start(CELL_EXPORT);
  cell_export(grid);
  timer_stop(CELL_EXPORT);

  timer_start(FWD_EXPORT);
  forward_comm_export_list(grid, mpp);
  timer_stop(FWD_EXPORT);

  timer_start(CELL_IMPORT);
  cell_import(grid);
  timer_stop(CELL_IMPORT);

  timer_start(FWD_MOST);
  forward_comm_most(grid, mpp);
  timer_stop(FWD_MOST);

  if (ff->present & M_BOND) {
    // Intentionally empty: the bond search/exchange below is disabled.
    //find_bonds(ff->grid);
    //forward_comm_bond(ff->grid, mpp);
  }
  grid->topo->update(mpp, grid);

#ifdef __sw__
  build_pack_params_sw((sw_archdata_t*)grid->arch_data, grid);
#endif
}
/**
 * Globally reduced sum of mass * |v|^2, scaled by conf->mvv2e.
 *
 * Each rank accumulates its local cells, then comm_allreduce() combines the
 * partial sums with MPI_SUM, so every rank returns the same value.
 *
 * @param grid cell grid to scan
 * @param mpp  communication context used for the all-reduce
 * @param conf supplies the mvv2e conversion factor
 * @return (sum over all ranks of m * v^2) * conf->mvv2e
 */
real get_ke(cellgrid_t *grid, mpp_t *mpp, esmd_config_t *conf) {
  real mvv_total = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
    for (int atom = 0; atom < cell->natom; ++atom)
      mvv_total += cell->v[atom].norm2() * cell->mass[atom];
  }
  comm_allreduce(&mvv_total, 1, mpi_real, MPI_SUM, mpp);
  const real scaled = mvv_total * conf->mvv2e;
  return scaled;
}
#ifdef __sw__
/* Argument bundle handed to verlet_reduce_print_stat() through qthread_pend_host(). */
typedef struct{
  mdstat_t *stat;  /* statistics to reduce (updated in place by the reduction) */
  mpp_t *mpp;      /* communication context */
  int istep;       /* step number printed with the energies */
} reduce_stat_param_t;

/**
 * All-reduce the statistics in pm->stat, then print the energy breakdown on
 * the process whose pid is 0. Other processes only take part in the reduction.
 */
void verlet_reduce_print_stat(reduce_stat_param_t *pm){
  comm_allreduce_stat(pm->stat, pm->mpp);
  if (pm->mpp->pid != 0)
    return;

  mdstat_t *st = pm->stat;
  /* Total is summed in the same term order as it is printed. */
  real te = st->ecoul + st->evdwl + st->ebond + st->eangle + st->etori + st->eimpr;
  printf("istep = %d: %f %f %f %f %f %f %f\n", pm->istep, te, st->ecoul, st->evdwl, st->ebond, st->eangle, st->etori, st->eimpr);
}
#endif
// Timers used by integrate_verlet() below: the whole step loop, the
// thermostat calls, the two halo exchanges and restart output.
DEF_TIMER(LOOP, "loop");
DEF_TIMER(TEMP_INTEGRATE, "temp integrate")
DEF_TIMER(FORWARD_X, "forward x")
DEF_TIMER(REVERSE_F, "reverse f")
DEF_TIMER (REST, "write restart")
// extern listed_force_grid<harmonic_bond_param, MAX_BONDED_CELL> listed_grid;
void integrate_verlet(long start_step, ff_t *ff, rigid_t *rigid, mpp_t *mpp, esmd_config_t *conf) {
  /*get a0 first*/
  // FILE *dump = fopen("esmd.dump", "w");
  char restart_prefix[PATH_MAX];
  if (conf->restart)
    expand_path(restart_prefix, conf->root, "%s", conf->restart);
  char trajpath[PATH_MAX];
  if (conf->trajectory)
    expand_path(trajpath, conf->root, "%s-%06d.estraj", conf->trajectory, mpp->pid);
  if (conf->restart)
    write_restart(restart_prefix, 0, 0, mpp, ff->grid);
  // dcd_file_t dcd;
  // dcd_header_init(&dcd, "esmd.dcd", ff->grid, mpp, conf->dt);
  // dcd.header.framestepcnt = conf->trajfreq;
  // check_nan(ff);
  timer_start(REST);
  // write_restart("restart1", 0, 0, mpp, ff->grid);
  // read_restart("restart-0-000000.esrst", ff->grid);
  timer_stop(REST);
  timer_start(FORCE);
  force(ff);
  timer_stop(FORCE);
  // check_nan(ff);
#ifdef __sw__
  mdstat_t last_stat;
  reduce_stat_param_t reduce_param = {&last_stat, mpp};
  build_pack_params_sw((sw_archdata_t*)ff->grid->arch_data, ff->grid);
#endif
  reverse_comm_f(ff->grid, mpp);
  
  real ftm2v = 1 / 48.88821291 / 48.88821201;
  if (rigid->setup)
    rigid->setup(ff->grid, mpp, conf->dt, ftm2v);
  // check_nan(ff);
  int nbuild = 0;
  // mdstat_t ff->stat;
  // start_prof();
  MPI_Barrier(MPI_COMM_WORLD);
  timer_start(LOOP);
  real vscale = 1.0;
  real ke_current = 0;
  if (ff->present & M_TEMP_INTEGRATE) {
    ke_current = get_ke(ff->grid, mpp, conf);
  }
  int itape = 0;
  long i;
  for (i = start_step; i < conf->nsteps; i++) {
    if (conf->restart && (i + 1) % conf->restfreq == 0) {
       itape ^= 1;
       write_restart(restart_prefix, itape, i + 1, mpp, ff->grid);
    }
    if (conf->trajectory && i % conf->trajfreq == 0) {
      write_trajectory(trajpath, i, ff->grid);
      // dcd_write_frame(&dcd, ff->grid, mpp);
    }
    if (ff->present & M_TEMP_INTEGRATE) {
      timer_start(TEMP_INTEGRATE);
      real vfactor = ff->temp_integrate(ff->grid, mpp, ff->tstat, i / conf->nsteps, ke_current * vscale * vscale);
      vscale *= vfactor;
      // printf("%f\n", vfactor);
      timer_stop(TEMP_INTEGRATE);
    }
    timer_start(INITIAL_INTEGRATE);
    int need_rebuild = initial_integrate_verlet(ff->grid, conf->dt, vscale);
    vscale = 1.0;
    #ifndef __sw__
    ff->stat.need_rebuild = need_rebuild;
    
    comm_reduce_stat(&ff->stat, mpp);
    if (mpp->pid == 0) {
      real te = ff->stat.ecoul + ff->stat.evdwl + ff->stat.ebond + ff->stat.eangle + ff->stat.etori + ff->stat.eimpr;
      printf("istep = %d: %f %f %f %f %f %f %f\n", i, te, ff->stat.ecoul, ff->stat.evdwl, ff->stat.ebond, ff->stat.eangle, ff->stat.etori, ff->stat.eimpr);
    }
    timer_start(CELLS);
    if (ff->stat.need_rebuild) {
      update_cell_grid(ff, mpp);
      nbuild++;
    }
    // int built = post_coordinate_update(ff, mpp);
    timer_stop(CELLS);

    #else
    memcpy(&last_stat, &ff->stat, sizeof(last_stat));
    last_stat.need_rebuild = need_rebuild;
    qthread_pend_host(verlet_reduce_print_stat, &reduce_param);
    reduce_param.istep = i;
    // qthread_flush_host();
    // if (last_stat.need_rebuild > 0) {
    //   timer_start(CELLS);
    //   update_cell_grid(ff, mpp);
    //   timer_stop(CELLS);
    //   if (i - ((sw_archdata_t*)ff->grid->arch_data)->last_lbal * conf->dt > 100) {
    //     ((sw_archdata_t*)ff->grid->arch_data)->last_lbal = i;
    //     do_lbal(ff->grid);
    //   }
    // }

    // printf("pend: %d\n", i);
    #endif
    timer_stop(INITIAL_INTEGRATE);

    //check_nan(ff);
    // if (i == 0 || built) {
    //   ff->grid->updated = 1;
    // }
    timer_start(FORWARD_X);
    forward_comm_x(ff->grid, mpp);
    timer_stop(FORWARD_X);
    timer_start(FORCE);
    force(ff);
    #ifdef __sw__
    qthread_flush_host();
    #endif
    timer_stop(FORCE);
    #ifdef __sw__
    if (last_stat.need_rebuild > 0) {
      // puts("rebuild");
      timer_start(CELLS);
      update_cell_grid(ff, mpp);
      // puts("rebuild");
      if (i - ((sw_archdata_t*)ff->grid->arch_data)->last_lbal * conf->dt > 100) {
        ((sw_archdata_t*)ff->grid->arch_data)->last_lbal = i;
        do_lbal(ff->grid);
      }
      timer_stop(CELLS);
      timer_start(FORCE);
      force(ff);
      timer_stop(FORCE);

      need_rebuild = 0;
    }
    #endif
    timer_start(REVERSE_F);
    reverse_comm_f(ff->grid, mpp);
    timer_stop(REVERSE_F);
    if (rigid->mask & RIGID_POST_FORCE) {
      timer_start(CONSTRAINT);
      rigid->post_force(ff->grid, mpp, conf->dt, ftm2v);
      // check_nan(ff);
      timer_stop(CONSTRAINT);
    }

    //check_nan(ff);
    // shake_post_force(ff->grid, mpp, conf->dt, ftm2v);
    timer_start(FINAL_INTEGRATE);
    final_integrate_verlet(ff->grid, conf->dt);
    timer_stop(FINAL_INTEGRATE);
    if (rigid->mask & RIGID_FINAL_INTEGRATE) {
      timer_start(CONSTRAINT);
      rigid->final_integrate(ff->grid, mpp, conf->dt, ftm2v);
      timer_stop(CONSTRAINT);
    }
    if (ff->present & M_TEMP_INTEGRATE) {
      timer_start(TEMP_INTEGRATE);
      ke_current = get_ke(ff->grid, mpp, conf);
      real vfactor = ff->temp_integrate(ff->grid, mpp, ff->tstat, i / conf->nsteps, ke_current * vscale * vscale);
      vscale *= vfactor;
      // printf("%f\n", vfactor);
      timer_stop(TEMP_INTEGRATE);
    }
    // check_nan(ff);
    ff->grid->updated = 0;
  }
  timer_stop(LOOP);
  MPI_Barrier(MPI_COMM_WORLD);
  if (vscale != 1.0) {
    /* set dt to zero, that is, scale velocity only */
    initial_integrate_verlet(ff->grid, 0, vscale);
  }
  // pause_prof();
  comm_reduce_stat(&ff->stat, mpp);
  if (mpp->pid == 0) {
    printf("istep = %d: %f %f %f %f %f %f\n", i, ff->stat.ecoul, ff->stat.evdwl, ff->stat.ebond, ff->stat.eangle, ff->stat.etori, ff->stat.eimpr);
    printf("nsteps=%d, %d cell list rebuilds\n", conf->nsteps, nbuild);
  }
  // dcd_close(&dcd, mpp);
  //fclose(dump);
}
/**
 * Debug helper: scans positions, velocities and forces of every atom in the
 * local cells and prints three 0/1 flags ("x v f"). A flag is 1 when any
 * component of that quantity fails is_finite_real().
 *
 * @param ff force-field context whose grid is scanned
 */
void check_nan(ff_t *ff) {
  int bad_x = 0;
  int bad_v = 0;
  int bad_f = 0;
  FOREACH_LOCAL_CELL(ff->grid, cx, cy, cz, cell) {
    for (int atom = 0; atom < cell->natom; ++atom) {
      const bool x_ok = is_finite_real(cell->x[atom].x) && is_finite_real(cell->x[atom].y) &&
                        is_finite_real(cell->x[atom].z);
      const bool v_ok = is_finite_real(cell->v[atom].x) && is_finite_real(cell->v[atom].y) &&
                        is_finite_real(cell->v[atom].z);
      const bool f_ok = is_finite_real(cell->f[atom].x) && is_finite_real(cell->f[atom].y) &&
                        is_finite_real(cell->f[atom].z);
      if (!x_ok)
        bad_x = 1;
      if (!v_ok)
        bad_v = 1;
      if (!f_ok)
        bad_f = 1;
    }
  }
  printf("%d %d %d\n", bad_x, bad_v, bad_f);
}
